This tutorial illustrates the core visualization utilities available in Ax.
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.modelbridge.cross_validation import cross_validate
from ax.plot.contour import interact_contour
from ax.plot.diagnostic import interact_cross_validation
from ax.plot.scatter import (
    interact_fitted,
    plot_objective_vs_constraints,
    tile_fitted,
)
from ax.plot.slice import plot_slice
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()
The visualizations require an experiment object and a model fit on the evaluated data. The routine below is copied from the Service API tutorial, so its explanation is omitted here. How to retrieve the experiment and model objects under each API paradigm is shown in the respective tutorials.
noise_sd = 0.1
param_names = [f"x{i+1}" for i in range(6)] # x1, x2, ..., x6
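def noisy_hartmann_evaluation_function(parameterization):
    # Collect x1..x6 into a vector, in a fixed order.
    x = np.array([parameterization.get(p_name) for p_name in param_names])
    # Independent Gaussian noise for each of the two metrics.
    noise1, noise2 = np.random.normal(0, noise_sd, 2)
    # Each metric maps to a (mean, SEM) tuple.
    return {
        "hartmann6": (hartmann6(x) + noise1, noise_sd),
        "l2norm": (np.sqrt((x**2).sum()) + noise2, noise_sd),
    }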
ax_client = AxClient()
ax_client.create_experiment(
    name="test_visualizations",
    parameters=[
        {
            "name": p_name,
            "type": "range",
            "bounds": [0.0, 1.0],
        }
        for p_name in param_names
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
    outcome_constraints=["l2norm <= 1.25"],
)
[INFO 08-09 10:10:37] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x3. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x4. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x5. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x6. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 10:10:37] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).
[INFO 08-09 10:10:37] ax.modelbridge.dispatch_utils: Using Models.GPEI since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 08-09 10:10:37] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=6 num_trials=None use_batch_trials=False
[INFO 08-09 10:10:37] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=12
[INFO 08-09 10:10:37] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=12
[INFO 08-09 10:10:37] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 12 trials, GPEI for subsequent trials]). Iterations after 12 will take longer to generate due to model-fitting.
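For readers who want to see what the evaluation function computes without Ax installed, here is a pure-Python sketch of the standard Hartmann6 benchmark, f(x) = -Σ_i α_i exp(-Σ_j A_ij (x_j - P_ij)^2) on [0, 1]^6, using the constants from common benchmark tables. It assumes Ax's `hartmann6` follows this same textbook definition; `hartmann6_sketch` and `noisy_evaluation_sketch` are hypothetical stand-ins, and `random.gauss` replaces NumPy's normal sampler.

```python
import math
import random
from typing import Dict, Mapping, Optional, Sequence, Tuple

# Standard Hartmann6 constants as published in common benchmark tables
# (assumed to match Ax's hartmann6).
ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
]
P_RAW = [
    [1312, 1696, 5569, 124, 8283, 5886],
    [2329, 4135, 8307, 3736, 1004, 9991],
    [2348, 1451, 3522, 2883, 3047, 6650],
    [4047, 8828, 8732, 5743, 1091, 381],
]
P = [[1e-4 * v for v in row] for row in P_RAW]

PARAM_NAMES = [f"x{i + 1}" for i in range(6)]


def hartmann6_sketch(x: Sequence[float]) -> float:
    """Hypothetical pure-Python Hartmann6: -sum_i alpha_i * exp(-sum_j A_ij (x_j - P_ij)^2)."""
    total = 0.0
    for alpha_i, a_row, p_row in zip(ALPHA, A, P):
        inner = sum(a * (xj - p) ** 2 for a, xj, p in zip(a_row, x, p_row))
        total += alpha_i * math.exp(-inner)
    return -total


def noisy_evaluation_sketch(
    parameterization: Mapping[str, float],
    noise_sd: float = 0.1,
    rng: Optional[random.Random] = None,
) -> Dict[str, Tuple[float, float]]:
    """Hypothetical analogue of the tutorial's function: metric -> (noisy mean, SEM)."""
    rng = rng or random.Random()
    x = [parameterization[name] for name in PARAM_NAMES]
    l2norm = math.sqrt(sum(v * v for v in x))
    return {
        "hartmann6": (hartmann6_sketch(x) + rng.gauss(0.0, noise_sd), noise_sd),
        "l2norm": (l2norm + rng.gauss(0.0, noise_sd), noise_sd),
    }
```

At the known minimizer x* ≈ (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573) this definition gives a value of about -3.32237, the benchmark's global minimum.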
for i in range(20):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to an external system.
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=noisy_hartmann_evaluation_function(parameters)
    )
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.15074, 'x2': 0.0842, 'x3': 0.016143, 'x4': 0.800792, 'x5': 0.353366, 'x6': 0.915859}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.227969, 0.1), 'l2norm': (1.187052, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.068763, 'x2': 0.879727, 'x3': 0.347767, 'x4': 0.199834, 'x5': 0.052906, 'x6': 0.03542}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-0.214187, 0.1), 'l2norm': (0.941789, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.769331, 'x2': 0.736825, 'x3': 0.99994, 'x4': 0.442988, 'x5': 0.011293, 'x6': 0.906881}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.019162, 0.1), 'l2norm': (1.761496, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.945698, 'x2': 0.623851, 'x3': 0.449955, 'x4': 0.664283, 'x5': 0.537995, 'x6': 0.535811}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.053037, 0.1), 'l2norm': (1.794504, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.106036, 'x2': 0.264647, 'x3': 0.288887, 'x4': 0.429495, 'x5': 0.863059, 'x6': 0.150376}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.035371, 0.1), 'l2norm': (1.125989, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.810396, 'x2': 0.838967, 'x3': 0.97894, 'x4': 0.70397, 'x5': 0.791562, 'x6': 0.665369}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-0.030889, 0.1), 'l2norm': (2.043111, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.588287, 'x2': 0.439574, 'x3': 0.291996, 'x4': 0.147348, 'x5': 0.706596, 'x6': 0.007011}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.290254, 0.1), 'l2norm': (0.890949, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.432433, 'x2': 0.068439, 'x3': 0.722218, 'x4': 0.96866, 'x5': 0.571829, 'x6': 0.005293}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (0.024098, 0.1), 'l2norm': (1.30791, 0.1)}.
[INFO 08-09 10:10:37] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.236009, 'x2': 0.708099, 'x3': 0.296655, 'x4': 0.109109, 'x5': 0.171306, 'x6': 0.482572}.
[INFO 08-09 10:10:37] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.389625, 0.1), 'l2norm': (0.878258, 0.1)}.
[INFO 08-09 10:10:38] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.258013, 'x2': 0.234328, 'x3': 0.684539, 'x4': 0.277323, 'x5': 0.050236, 'x6': 0.482629}.
[INFO 08-09 10:10:38] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.753622, 0.1), 'l2norm': (0.966709, 0.1)}.
[INFO 08-09 10:10:38] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.272263, 'x2': 0.521808, 'x3': 0.953695, 'x4': 0.439802, 'x5': 0.13975, 'x6': 0.843858}.
[INFO 08-09 10:10:38] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.951838, 0.1), 'l2norm': (1.433137, 0.1)}.
[INFO 08-09 10:10:38] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.938858, 'x2': 0.479084, 'x3': 0.276589, 'x4': 0.2547, 'x5': 0.099248, 'x6': 0.553535}.
[INFO 08-09 10:10:38] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.343751, 0.1), 'l2norm': (1.246488, 0.1)}.
[INFO 08-09 10:10:46] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 0.251435, 'x2': 0.380041, 'x3': 0.796625, 'x4': 0.328275, 'x5': 0.099283, 'x6': 0.652016}.
[INFO 08-09 10:10:46] ax.service.ax_client: Completed trial 12 with data: {'hartmann6': (-1.055495, 0.1), 'l2norm': (1.120871, 0.1)}.
[INFO 08-09 10:11:05] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 0.191108, 'x2': 0.412332, 'x3': 0.8443, 'x4': 0.31151, 'x5': 0.096408, 'x6': 0.730741}.
[INFO 08-09 10:11:05] ax.service.ax_client: Completed trial 13 with data: {'hartmann6': (-1.27532, 0.1), 'l2norm': (1.218943, 0.1)}.
[INFO 08-09 10:11:27] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.132661, 'x2': 0.446646, 'x3': 0.806017, 'x4': 0.287188, 'x5': 0.099783, 'x6': 0.671705}.
[INFO 08-09 10:11:27] ax.service.ax_client: Completed trial 14 with data: {'hartmann6': (-1.052546, 0.1), 'l2norm': (0.996876, 0.1)}.
[INFO 08-09 10:11:42] ax.service.ax_client: Generated new trial 15 with parameters {'x1': 0.173732, 'x2': 0.348243, 'x3': 0.902514, 'x4': 0.305257, 'x5': 0.058143, 'x6': 0.757746}.
[INFO 08-09 10:11:42] ax.service.ax_client: Completed trial 15 with data: {'hartmann6': (-0.974881, 0.1), 'l2norm': (1.16557, 0.1)}.
[INFO 08-09 10:11:55] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 0.209563, 'x2': 0.443826, 'x3': 0.792923, 'x4': 0.287932, 'x5': 0.150769, 'x6': 0.761644}.
[INFO 08-09 10:11:55] ax.service.ax_client: Completed trial 16 with data: {'hartmann6': (-1.64282, 0.1), 'l2norm': (1.252676, 0.1)}.
[INFO 08-09 10:12:13] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 0.205443, 'x2': 0.408384, 'x3': 0.737058, 'x4': 0.222147, 'x5': 0.192074, 'x6': 0.748437}.
[INFO 08-09 10:12:13] ax.service.ax_client: Completed trial 17 with data: {'hartmann6': (-1.975329, 0.1), 'l2norm': (1.104612, 0.1)}.
[INFO 08-09 10:12:22] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 0.206531, 'x2': 0.38294, 'x3': 0.726706, 'x4': 0.195883, 'x5': 0.222523, 'x6': 0.793202}.
[INFO 08-09 10:12:22] ax.service.ax_client: Completed trial 18 with data: {'hartmann6': (-2.127045, 0.1), 'l2norm': (1.047214, 0.1)}.
[INFO 08-09 10:12:27] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 0.249057, 'x2': 0.41444, 'x3': 0.710695, 'x4': 0.150958, 'x5': 0.226, 'x6': 0.835997}.
[INFO 08-09 10:12:27] ax.service.ax_client: Completed trial 19 with data: {'hartmann6': (-1.928836, 0.1), 'l2norm': (1.10234, 0.1)}.
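The experiment minimizes hartmann6 subject to l2norm <= 1.25. As a rough sanity check on the logs above, the sketch below scans the raw (noisy) observations for the lowest hartmann6 value among trials whose observed l2norm satisfies the bound. It only looks at raw data; Ax's own best-point utilities may rely on model predictions instead, and `best_feasible` is a hypothetical helper.

```python
from typing import List, Optional, Tuple

# (hartmann6 mean, l2norm mean) for trials 0-19, copied from the log output above.
OBSERVED: List[Tuple[float, float]] = [
    (-0.227969, 1.187052), (-0.214187, 0.941789), (-0.019162, 1.761496),
    (-0.053037, 1.794504), (-0.035371, 1.125989), (-0.030889, 2.043111),
    (-0.290254, 0.890949), (0.024098, 1.30791), (-0.389625, 0.878258),
    (-0.753622, 0.966709), (-0.951838, 1.433137), (-0.343751, 1.246488),
    (-1.055495, 1.120871), (-1.27532, 1.218943), (-1.052546, 0.996876),
    (-0.974881, 1.16557), (-1.64282, 1.252676), (-1.975329, 1.104612),
    (-2.127045, 1.047214), (-1.928836, 1.10234),
]


def best_feasible(
    observed: List[Tuple[float, float]], bound: float = 1.25
) -> Optional[Tuple[int, float]]:
    """Return (trial_index, objective) with the lowest objective among trials
    whose observed constraint value is <= bound, or None if none qualify."""
    feasible = [(i, obj) for i, (obj, con) in enumerate(observed) if con <= bound]
    if not feasible:
        return None
    return min(feasible, key=lambda pair: pair[1])


print(best_feasible(OBSERVED))  # (18, -2.127045)
```

Trial 11 (l2norm 1.246488) only narrowly counts as feasible; with noise_sd = 0.1, the true constraint value of such borderline points is uncertain.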
The plot below shows the response surface of the hartmann6 metric, as predicted by the fitted model, as a function of the parameters x1 and x2.
The other parameters are fixed at the midpoints of their ranges, which in this example is 0.5 for all of them.
# this could alternatively be done with `ax.plot.contour.plot_contour`
render(ax_client.get_contour_plot(param_x="x1", param_y="x2", metric_name="hartmann6"))
[INFO 08-09 10:12:27] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'hartmann6'. Remaining parameters are affixed to the middle of their range.
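To make "remaining parameters are affixed to the middle of their range" concrete, the sketch below evaluates a function on a grid over two chosen parameters while pinning the other four at 0.5. It uses the noiseless l2norm as a stand-in for the fitted model; this is an assumption for illustration only, since Ax's contour plot shows the model's predictions rather than a known function. `contour_grid` is a hypothetical helper.

```python
import math
from typing import Callable, Dict, List

PARAM_NAMES = [f"x{i + 1}" for i in range(6)]


def l2norm(params: Dict[str, float]) -> float:
    """Noiseless stand-in for the model: Euclidean norm of x1..x6."""
    return math.sqrt(sum(v * v for v in params.values()))


def contour_grid(
    f: Callable[[Dict[str, float]], float],
    param_x: str,
    param_y: str,
    n: int = 5,
    lower: float = 0.0,
    upper: float = 1.0,
) -> List[List[float]]:
    """Return grid[j][i] = f at (param_x = ticks[i], param_y = ticks[j]),
    with every other parameter fixed at the midpoint of [lower, upper]."""
    midpoint = (lower + upper) / 2
    ticks = [lower + k * (upper - lower) / (n - 1) for k in range(n)]
    grid = []
    for y in ticks:
        row = []
        for x in ticks:
            params = {name: midpoint for name in PARAM_NAMES}
            params[param_x] = x
            params[param_y] = y
            row.append(f(params))
        grid.append(row)
    return grid


grid = contour_grid(l2norm, "x1", "x2")
print(round(grid[2][2], 4))  # 1.2247, i.e. sqrt(1.5) with all six at 0.5
```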
The plot below allows toggling between different pairs of parameters to view the contours.
model = ax_client.generation_strategy.model
render(interact_contour(model=model, metric_name="hartmann6"))
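With six parameters there are C(6, 2) = 15 unordered pairs to inspect, which is why an interactive selector is more practical than rendering every contour plot separately. The enumeration itself is a one-liner with the standard library:

```python
from itertools import combinations

param_names = [f"x{i + 1}" for i in range(6)]
# Every unordered pair (x_a, x_b) with a < b.
pairs = list(combinations(param_names, 2))

print(len(pairs))  # 15
print(pairs[:3])  # [('x1', 'x2'), ('x1', 'x3'), ('x1', 'x4')]
```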
This plot illustrates the tradeoffs achievable between two metrics. It takes the x-axis metric as input (usually the objective) and allows toggling among all other metrics for the y-axis.
This is useful for getting a sense of the Pareto frontier (i.e., the best objective value achievable for different bounds on the constraint).
render(plot_objective_vs_constraints(model, "hartmann6", rel=False))
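The Pareto frontier consists of the points that no other point beats on both metrics at once. Below is a minimal sketch of that definition for two metrics that are both minimized (for example an l2norm value and a hartmann6 value, since a smaller l2norm makes the constraint easier to satisfy). `pareto_front` is a hypothetical helper written for illustration, not how `plot_objective_vs_constraints` is implemented, and the toy points are made up.

```python
from typing import List, Tuple


def pareto_front(points: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Return the non-dominated points when both coordinates are minimized.

    p dominates q if p is <= q in both coordinates and strictly < in at least one.
    This is an O(n^2) scan, fine for a few dozen trials.
    """
    front = []
    for q in points:
        dominated = any(
            p[0] <= q[0] and p[1] <= q[1] and (p[0] < q[0] or p[1] < q[1])
            for p in points
        )
        if not dominated:
            front.append(q)
    return sorted(front)


# Toy (constraint, objective) pairs.
print(pareto_front([(1, 3), (2, 2), (3, 1), (2, 3), (3, 3)]))  # [(1, 3), (2, 2), (3, 1)]
```

Reading the front from left to right answers the question posed above: each step to a looser constraint bound buys a better best-achievable objective.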
Cross-validation (CV) plots are useful for checking how well the model's predictions agree with the actual measurements. If all points lie close to the dashed line, the model is a good predictor of the real data.
cv_results = cross_validate(model)
render(interact_cross_validation(cv_results))
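Each point in a CV plot pairs an observed value with the model's prediction for that point when it was held out of training. Ax's `cross_validate` should, by default, be a leave-one-out procedure (its `folds` argument defaults to -1). The sketch below shows leave-one-out with a one-dimensional least-squares line standing in for the Gaussian process model, an assumption made purely to keep the example self-contained; `fit_line` and `leave_one_out` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit y = slope * x + intercept."""
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("need at least two distinct x values to fit a line")
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, my - slope * mx


def leave_one_out(xs: Sequence[float], ys: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (observed, predicted) pairs, predicting each point from a model
    fit on all the other points."""
    xs, ys = list(xs), list(ys)
    pairs = []
    for i in range(len(xs)):
        slope, intercept = fit_line(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:])
        pairs.append((ys[i], slope * xs[i] + intercept))
    return pairs
```

On perfectly linear data every held-out prediction matches its observation, so all points would sit on the dashed line; systematic departures from that line in a real CV plot suggest the model is poorly calibrated.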
Slice plots show the metric outcome as a function of one parameter while the others are held fixed. They serve a similar purpose to contour plots.
render(plot_slice(model, "x2", "hartmann6"))
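A slice is the one-dimensional analogue of a contour: vary one parameter across its range and hold the rest fixed. The sketch below does this for x2, using the noiseless l2norm as a stand-in for the model and fixing the other parameters at 0.5. Both choices are assumptions for illustration; the values at which Ax's `plot_slice` holds the remaining parameters may differ, and `slice_values` is a hypothetical helper.

```python
import math
from typing import Callable, Dict, List, Tuple

PARAM_NAMES = [f"x{i + 1}" for i in range(6)]


def l2norm(params: Dict[str, float]) -> float:
    """Noiseless stand-in for the model: Euclidean norm of x1..x6."""
    return math.sqrt(sum(v * v for v in params.values()))


def slice_values(
    f: Callable[[Dict[str, float]], float],
    param: str,
    n: int = 5,
    fixed: float = 0.5,
) -> List[Tuple[float, float]]:
    """Return (value of param, f) pairs as param sweeps [0, 1] and every
    other parameter stays at `fixed`."""
    out = []
    for k in range(n):
        t = k / (n - 1)
        params = {name: fixed for name in PARAM_NAMES}
        params[param] = t
        out.append((t, f(params)))
    return out


print(slice_values(l2norm, "x2")[-1])  # (1.0, 1.5)
```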
Tile plots are useful for viewing the effect of each arm: for every arm, they show the model's fitted value of the selected metric.
render(interact_fitted(model, rel=False))
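Each tile summarizes one arm with a point estimate and an error bar. The sketch below builds a plain-text version of that summary from the raw (mean, SEM) data logged for trials 0 to 2, using a normal-approximation 95% interval, mean ± 1.96 × SEM. It rests on assumptions: the real tile plot shows model-fitted values rather than raw observations, arm names of the form `<trial>_<arm>` are assumed, and `confidence_interval` and `tile_rows` are hypothetical helpers.

```python
from typing import Dict, List, Tuple

# Raw (mean, SEM) data for trials 0-2, copied from the log output above.
# The "<trial>_<arm>" arm names are an assumption about Ax's naming scheme.
ARM_DATA: Dict[str, Dict[str, Tuple[float, float]]] = {
    "0_0": {"hartmann6": (-0.227969, 0.1), "l2norm": (1.187052, 0.1)},
    "1_0": {"hartmann6": (-0.214187, 0.1), "l2norm": (0.941789, 0.1)},
    "2_0": {"hartmann6": (-0.019162, 0.1), "l2norm": (1.761496, 0.1)},
}


def confidence_interval(mean: float, sem: float, z: float = 1.96) -> Tuple[float, float]:
    """Normal-approximation interval mean +/- z * SEM (z = 1.96 gives ~95%)."""
    return mean - z * sem, mean + z * sem


def tile_rows(arm_data: Dict[str, Dict[str, Tuple[float, float]]], metric: str) -> List[str]:
    """One text row per arm (in insertion order): mean and 95% interval."""
    rows = []
    for arm_name, metrics in arm_data.items():
        mean, sem = metrics[metric]
        lo, hi = confidence_interval(mean, sem)
        rows.append(f"{arm_name}: {mean:.3f} [{lo:.3f}, {hi:.3f}]")
    return rows


for row in tile_rows(ARM_DATA, "hartmann6"):
    print(row)
```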
Total runtime of script: 2 minutes, 23.41 seconds.